{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "==== Multi-Head Attention Debugging ====\n",
      "\n",
      "Step 1: Inputs to MHA\n",
      "  Query Shape: torch.Size([1, 32, 768])\n",
      "  Key Shape: torch.Size([1, 32, 768])\n",
      "  Value Shape: torch.Size([1, 32, 768])\n",
      "\n",
      "Step 2: Split into 12 Heads\n",
      "  Query Split Shape: torch.Size([1, 12, 32, 64])\n",
      "  Key Split Shape: torch.Size([1, 12, 32, 64])\n",
      "  Value Split Shape: torch.Size([1, 12, 32, 64])\n",
      "\n",
      "Step 3: Compute Attention Scores (QK^T / sqrt(d_h))\n",
      "  Attention Scores Shape: torch.Size([1, 12, 32, 32])\n",
      "Step 4: Apply Softmax to Get Attention Weights\n",
      "  Attention Weights Shape: torch.Size([1, 12, 32, 32])\n",
      "\n",
      "Step 5: Compute SV (Attention Weights * V)\n",
      "  SV (Attention Output Before Merge) Shape: torch.Size([1, 12, 32, 64])\n",
      "\n",
      "Step 6: Merge Heads Back\n",
      "  Final MHA Output Shape: torch.Size([1, 32, 768])\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[[0.4388, 0.5095, 0.5035,  ..., 0.5491, 0.5187, 0.5283],\n",
       "         [0.4288, 0.4965, 0.5083,  ..., 0.5577, 0.5238, 0.5086],\n",
       "         [0.4437, 0.5094, 0.5236,  ..., 0.5610, 0.5125, 0.5377],\n",
       "         ...,\n",
       "         [0.4654, 0.5045, 0.5149,  ..., 0.5473, 0.5233, 0.5203],\n",
       "         [0.4599, 0.5065, 0.5262,  ..., 0.5363, 0.5304, 0.5120],\n",
       "         [0.4452, 0.5029, 0.5358,  ..., 0.5510, 0.5219, 0.5326]]])"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import DistilBertModel\n",
    "\n",
    "# Load pretrained DistilBERT model\n",
    "model = DistilBertModel.from_pretrained(\"distilbert-base-uncased\")\n",
    "\n",
    "# Extract the first transformer's multi-head attention layer\n",
    "layer = model.transformer.layer[0]\n",
    "attention_layer = layer.attention\n",
    "\n",
    "# Override the forward function to print intermediate steps\n",
    "def debug_mha_forward(self, query, key, value, mask=None, head_mask=None):\n",
    "    print(\"\\n==== Multi-Head Attention Debugging ====\\n\")\n",
    "\n",
    "    # Step 1: Print Input Shapes\n",
    "    print(f\"Step 1: Inputs to MHA\")\n",
    "    print(f\"  Query Shape: {query.shape}\")  # Expected: (batch, seq_len, hidden_size)\n",
    "    print(f\"  Key Shape: {key.shape}\")\n",
    "    print(f\"  Value Shape: {value.shape}\\n\")\n",
    "\n",
    "    batch_size, seq_length, hidden_size = query.shape\n",
    "\n",
    "    # Step 2: Split into multi-heads\n",
    "    def split_heads(tensor, num_heads, head_dim):\n",
    "        return tensor.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2)\n",
    "\n",
    "    num_heads = self.n_heads\n",
    "    head_dim = hidden_size // num_heads\n",
    "\n",
    "    query = split_heads(query, num_heads, head_dim)\n",
    "    key = split_heads(key, num_heads, head_dim)\n",
    "    value = split_heads(value, num_heads, head_dim)\n",
    "\n",
    "    print(f\"Step 2: Split into {num_heads} Heads\")\n",
    "    print(f\"  Query Split Shape: {query.shape}\")  # (batch, num_heads, seq_len, head_dim)\n",
    "    print(f\"  Key Split Shape: {key.shape}\")\n",
    "    print(f\"  Value Split Shape: {value.shape}\\n\")\n",
    "\n",
    "    # Step 3: Compute Scaled Dot-Product Attention\n",
    "    attention_scores = torch.matmul(query, key.transpose(-2, -1)) / (head_dim ** 0.5)\n",
    "    print(f\"Step 3: Compute Attention Scores (QK^T / sqrt(d_h))\")\n",
    "    print(f\"  Attention Scores Shape: {attention_scores.shape}\")  # (batch, num_heads, seq_len, seq_len)\\n\n",
    "\n",
    "    if mask is not None:\n",
    "        attention_scores = attention_scores.masked_fill(mask == 0, float(\"-inf\"))\n",
    "\n",
    "    attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)\n",
    "    print(f\"Step 4: Apply Softmax to Get Attention Weights\")\n",
    "    print(f\"  Attention Weights Shape: {attention_probs.shape}\\n\")  # (batch, num_heads, seq_len, seq_len)\n",
    "\n",
    "    # Step 5: Compute SV = Attention Weights * V\n",
    "    attention_output = torch.matmul(attention_probs, value)\n",
    "    print(f\"Step 5: Compute SV (Attention Weights * V)\")\n",
    "    print(f\"  SV (Attention Output Before Merge) Shape: {attention_output.shape}\\n\")  # (batch, num_heads, seq_len, head_dim)\n",
    "\n",
    "    # Step 6: Merge Heads Back\n",
    "    attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_length, hidden_size)\n",
    "    print(f\"Step 6: Merge Heads Back\")\n",
    "    print(f\"  Final MHA Output Shape: {attention_output.shape}\\n\")  # (batch, seq_len, hidden_size)\n",
    "\n",
    "    return attention_output\n",
    "\n",
    "# Attach the debug function to the attention layer\n",
    "attention_layer.sa_forward = debug_mha_forward.__get__(attention_layer)\n",
    "\n",
    "# Create dummy input\n",
    "input_tensor = torch.rand(1, 32, 768)  # (batch_size=1, seq_len=64, hidden_size=768)\n",
    "\n",
    "# Run forward pass and print all steps\n",
    "attention_layer.sa_forward(input_tensor, input_tensor, input_tensor)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "==== Checking q_lin ====\n",
      "Input Shape: torch.Size([1, 32, 768])\n",
      "Weight Shape: torch.Size([768, 768])\n",
      "Bias Shape: torch.Size([768])\n",
      "Output Shape: torch.Size([1, 32, 768])\n",
      "\n",
      "==== Checking k_lin ====\n",
      "Input Shape: torch.Size([1, 32, 768])\n",
      "Weight Shape: torch.Size([768, 768])\n",
      "Bias Shape: torch.Size([768])\n",
      "Output Shape: torch.Size([1, 32, 768])\n",
      "\n",
      "==== Checking v_lin ====\n",
      "Input Shape: torch.Size([1, 32, 768])\n",
      "Weight Shape: torch.Size([768, 768])\n",
      "Bias Shape: torch.Size([768])\n",
      "Output Shape: torch.Size([1, 32, 768])\n",
      "\n",
      "==== Checking out_lin ====\n",
      "Input Shape: torch.Size([1, 32, 768])\n",
      "Weight Shape: torch.Size([768, 768])\n",
      "Bias Shape: torch.Size([768])\n",
      "Output Shape: torch.Size([1, 32, 768])\n",
      "\n",
      "✅ Q, K, V, O, projections produce consistent shapes.\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import DistilBertModel\n",
    "\n",
    "# Load DistilBERT model (pretrained)\n",
    "model = DistilBertModel.from_pretrained(\"distilbert-base-uncased\")\n",
    "\n",
    "# Define input tensor (batch_size=1, sequence_length=64, hidden_size=768)\n",
    "batch_size = 1\n",
    "sequence_length = 32\n",
    "hidden_size = 768\n",
    "\n",
    "input_tensor = torch.rand(batch_size, sequence_length, hidden_size)  # Simulated input embeddings\n",
    "\n",
    "# Extract first transformer layer\n",
    "layer = model.transformer.layer[0]\n",
    "\n",
    "# Define a helper function to track computation inside q_lin, k_lin, v_lin\n",
    "def check_projection(layer, tensor, name):\n",
    "    print(f\"\\n==== Checking {name} ====\")\n",
    "    print(f\"Input Shape: {tensor.shape}\")  # Should be (batch_size, S, d_model)\n",
    "\n",
    "    # Forward pass through the projection layer\n",
    "    output = layer.attention.__getattr__(name)(tensor)  # Equivalent to layer.attention.q_lin(tensor), etc.\n",
    "\n",
    "    print(f\"Weight Shape: {layer.attention.__getattr__(name).weight.shape}\")  # Should be (768, 768)\n",
    "    print(f\"Bias Shape: {layer.attention.__getattr__(name).bias.shape if layer.attention.__getattr__(name).bias is not None else 'None'}\")\n",
    "    print(f\"Output Shape: {output.shape}\")  # Should be (batch_size, S, d_model)\n",
    "\n",
    "    return output\n",
    "\n",
    "# Check input and output shapes for Q, K, V projections\n",
    "q_output = check_projection(layer, input_tensor, \"q_lin\")\n",
    "k_output = check_projection(layer, input_tensor, \"k_lin\")\n",
    "v_output = check_projection(layer, input_tensor, \"v_lin\")\n",
    "o_output = check_projection(layer, input_tensor, \"out_lin\")\n",
    "\n",
    "# Confirm output shape consistency\n",
    "assert q_output.shape == k_output.shape == v_output.shape == o_output.shape, \"Q, K, V, O output shapes do not match!\"\n",
    "print(\"\\n✅ Q, K, V, O, projections produce consistent shapes.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DistilBertSdpaAttention(\n",
      "  (dropout): Dropout(p=0.1, inplace=False)\n",
      "  (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
      "  (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
      "  (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
      "  (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "print(layer.attention)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FFN(\n",
      "  (dropout): Dropout(p=0.1, inplace=False)\n",
      "  (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
      "  (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
      "  (activation): GELUActivation()\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "print(layer.ffn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
